import math
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal timestep embeddings.

    With ``half_dim = embedding_dim // 2`` the frequencies are
    ``exp(-log(10000) * i / (half_dim - 1))`` for ``i in range(half_dim)``.
    The output is laid out as all sines followed by all cosines, i.e.
    ``[sin(t * f_0), ..., sin(t * f_{h-1}), cos(t * f_0), ..., cos(t * f_{h-1})]``.
    This follows the tensor2tensor / Fairseq / DDPM layout, which differs
    slightly from the interleaved description in Section 3.5 of
    "Attention Is All You Need".

    Args:
        timesteps: 1-D tensor of timesteps; it is cast to float internally.
        embedding_dim: width of the returned embedding.

    Returns:
        A float32 tensor of shape ``(len(timesteps), embedding_dim)`` on
        ``timesteps.device``. When ``embedding_dim`` is odd the last column
        is zero padding.

    Raises:
        AssertionError: if ``timesteps`` is not 1-D.
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    # half_dim == 1 (embedding_dim of 2 or 3) used to divide by zero here.
    # Clamping the denominator leaves every other width unchanged: for
    # half_dim == 0 the frequency vector is empty, so the value is unused.
    log_step = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -log_step)
    freqs = freqs.to(device=timesteps.device)

    # Outer product: one row per timestep, one column per frequency.
    angles = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad the trailing column
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional-encoding table with dropout.

    The table is stored as the buffer ``pe`` with shape
    ``(max_len, 1, d_model)``. Even feature columns hold sines and odd
    feature columns hold cosines of ``position * div_term``.

    Args:
        d_model: feature width of the encoding (even or odd).
        dropout: dropout probability applied in ``forward``.
        max_len: number of positions to precompute.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = position * div_term  # max_len x ceil(d_model / 2)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so
        # the cosine term must be trimmed; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # max_len x 1 x d_model

        # Registered as a buffer so it follows .to()/.cuda() and is saved in
        # the state_dict without being a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Add the encodings of the first ``x.shape[0]`` positions to ``x`` and
        apply dropout. The leading axis of ``x`` is treated as the position
        axis. (Marked in the original source as not used in the final model.)
        """
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)


class TimestepEmbedder(nn.Module):
    """
    Turn integer timesteps into learned embeddings.

    The timestep indexes the ``pe`` table of ``sequence_pos_encoder``; the
    looked-up rows are passed through a Linear -> SiLU -> Linear stack whose
    input and output widths are both ``latent_dim``.

    Args:
        latent_dim: width of the table rows and of the produced embedding.
        sequence_pos_encoder: object exposing a ``pe`` tensor indexable by
            timestep (presumably ``max_len x 1 x latent_dim`` -- confirm
            against the encoder that is passed in).
    """

    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder

        width = self.latent_dim
        first_proj = nn.Linear(self.latent_dim, width)
        second_proj = nn.Linear(width, width)
        self.time_embed = nn.Sequential(first_proj, nn.SiLU(), second_proj)

    def forward(self, timesteps):
        """
        Look up and project the encodings for ``timesteps``.

        The result has its first two axes swapped relative to the lookup, so
        a lookup of shape ``(B, 1, D)`` comes back as ``(1, B, D)``.
        """
        table = self.sequence_pos_encoder.pe
        looked_up = table[timesteps]
        projected = self.time_embed(looked_up)
        return projected.permute(1, 0, 2)


def nonlinearity(x):
    """Swish activation: element-wise ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate


def Normalize(in_channels):
    """
    Build the normalisation layer used by the convolutional blocks.

    Returns a ``GroupNorm`` with 32 groups, ``eps=1e-6`` and learnable affine
    parameters; ``in_channels`` must therefore be divisible by 32.
    """
    norm_kwargs = dict(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    return torch.nn.GroupNorm(**norm_kwargs)


class Upsample(nn.Module):
    """
    Double the spatial resolution with nearest-neighbour interpolation,
    optionally followed by a channel-preserving 3x3 convolution.

    Args:
        in_channels: number of input (and output) channels.
        with_conv: when truthy, a 3x3 / stride 1 / padding 1 convolution is
            applied after the interpolation.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if not self.with_conv:
            return
        self.conv = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return ``x`` upsampled by a factor of two (and convolved if enabled)."""
        upsampled = torch.nn.functional.interpolate(
            x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.with_conv else upsampled


class Downsample(nn.Module):
    """
    Halve the spatial resolution, either with a strided 3x3 convolution or
    with 2x2 average pooling.

    Args:
        in_channels: number of input (and output) channels.
        with_conv: when truthy use the strided convolution, otherwise use
            average pooling.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if not self.with_conv:
            return
        # The conv itself uses padding=0: the asymmetric padding it needs is
        # not expressible through Conv2d, so forward() pads by hand.
        self.conv = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        """Return ``x`` downsampled by a factor of two."""
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)

        # Zero-pad one pixel on the right and bottom edges only.
        right_bottom_pad = (0, 1, 0, 1)
        padded = torch.nn.functional.pad(
            x, right_bottom_pad, mode="constant", value=0)
        return self.conv(padded)


class ResnetBlock(nn.Module):
    """
    Residual convolutional block conditioned on a timestep embedding and,
    optionally, on an extra conditioning vector.

    Layout: norm -> swish -> conv3x3 -> (+ projected temb [+ projected cond])
    -> norm -> swish -> dropout -> conv3x3, added to a shortcut of the input.
    When the channel count changes the shortcut is a 3x3 convolution
    (``conv_shortcut=True``) or a 1x1 convolution (default).

    Args (keyword-only):
        in_channels: channels of the input feature map.
        out_channels: channels of the output; defaults to ``in_channels``.
        conv_shortcut: pick the 3x3 instead of the 1x1 shortcut.
        dropout: dropout probability before the second convolution.
        temb_channels: width of the timestep embedding.
        cond_channels: width of the conditioning vector; the conditioning
            projection is only created when this is positive.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512, cond_channels=-1):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)

        if cond_channels > 0:
            self.cond_proj = torch.nn.Linear(cond_channels, out_channels)

        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # A learned shortcut is only needed when the channel count changes.
        if self.in_channels == self.out_channels:
            return
        if self.use_conv_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(
                in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        else:
            self.nin_shortcut = torch.nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb, cond=None):
        """
        Apply the block.

        ``temb`` (and ``cond`` when given) are passed through swish and a
        linear projection, then broadcast over the two trailing spatial axes
        before being added to the feature map. Passing ``cond`` requires the
        block to have been built with ``cond_channels > 0``.
        """
        hidden = self.conv1(nonlinearity(self.norm1(x)))

        hidden = hidden + self.temb_proj(nonlinearity(temb))[:, :, None, None]
        if cond is not None:
            hidden = hidden + self.cond_proj(nonlinearity(cond))[:, :, None, None]

        hidden = nonlinearity(self.norm2(hidden))
        hidden = self.conv2(self.dropout(hidden))

        skip = x
        if self.in_channels != self.out_channels:
            shortcut = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
            skip = shortcut(x)

        return skip + hidden


class AttnBlock(nn.Module):
    """
    Single-head self-attention over the spatial positions of a feature map,
    with a residual connection.

    Queries, keys, values and the output projection are all 1x1
    convolutions that keep the channel count at ``in_channels``.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def pointwise_conv():
            return torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.norm = Normalize(in_channels)
        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

    def forward(self, x):
        """Return ``x`` plus the attended, projected features (same shape as ``x``)."""
        normed = self.norm(x)
        queries = self.q(normed)
        keys = self.k(normed)
        values = self.v(normed)

        b, c, h, w = queries.shape
        n_pos = h * w

        # Attention weights: scores[b, i, j] = sum_c q[b, i, c] * k[b, c, j],
        # scaled by 1/sqrt(c) and normalised over the key axis j.
        queries = queries.reshape(b, c, n_pos).permute(0, 2, 1)  # b, hw, c
        keys = keys.reshape(b, c, n_pos)                         # b, c, hw
        scores = torch.bmm(queries, keys)                        # b, hw, hw
        scores = scores * (int(c) ** (-0.5))
        scores = torch.nn.functional.softmax(scores, dim=2)

        # Weighted sum of values: out[b, c, j] = sum_i v[b, c, i] * w[b, i, j],
        # with w transposed so the first hw axis is keys, the second queries.
        values = values.reshape(b, c, n_pos)
        attended = torch.bmm(values, scores.permute(0, 2, 1))
        attended = attended.reshape(b, c, h, w)

        return x + self.proj_out(attended)


class ModelDiffResMLP(nn.Module):
    """
    Residual-MLP denoiser conditioned on a timestep and a conditioning vector.

    The input ``x``, the timestep ``t`` and ``cond`` are embedded separately
    (512, 256 and 256 features), concatenated, projected to
    ``res_hidden_dim``, passed through ``res_blocks`` residual MLP blocks and
    mapped to ``output_dim``.

    Config fields read: ``model.diffusion_rep``, ``model.res_blocks``,
    ``model.res_hidden_dim``, ``diffusion.num_diffusion_timesteps``,
    ``diffusion.separate_encodings``, ``invdyn.future_length`` and
    ``invdyn.history_length``.

    Raises:
        NotImplementedError: for an unknown ``model.diffusion_rep``.
    """

    # (representation name, per-future-frame width, per-history-frame width)
    _REP_FRAME_DIMS = (
        ('link_motion', 8 * 3, 8 * 3),
        ('qpos_motion', 16, 16),
        ('qposqtars_motion', 16, 16 + 16),
    )

    def __init__(self, config):
        super().__init__()
        self.config = config

        model_cfg = self.config.model
        self.diffusion_rep = model_cfg.diffusion_rep
        self.res_blocks = model_cfg.res_blocks
        self.res_hidden_dim = model_cfg.res_hidden_dim

        diffusion_cfg = self.config.diffusion
        self.num_timesteps = diffusion_cfg.num_diffusion_timesteps
        self.separate_encodings = diffusion_cfg.separate_encodings

        self.ts_embedding_dim = 256
        self.cond_embedding_dim = 256
        self.input_embedding_dim = 512

        # One table row per diffusion timestep; the embedder indexes it by t.
        self.timestep_positional_encoder = PositionalEncoding(
            self.ts_embedding_dim, 0.1, max_len=self.num_timesteps)
        self.timestep_embedder = TimestepEmbedder(
            self.ts_embedding_dim, self.timestep_positional_encoder)

        for rep_name, future_width, history_width in self._REP_FRAME_DIMS:
            if self.diffusion_rep == rep_name:
                break
        else:
            raise NotImplementedError

        self.input_dim = self.config.invdyn.future_length * future_width
        self.cond_input_dim = self.config.invdyn.history_length * history_width
        self.output_dim = self.config.invdyn.future_length * future_width

        half_input_width = self.input_embedding_dim // 2
        self.input_processing_layers = nn.Sequential(
            nn.Linear(self.input_dim, half_input_width),
            nn.ReLU(),
            nn.Linear(half_input_width, self.input_embedding_dim),
        )
        self.cond_input_processing_layers = nn.Sequential(
            nn.Linear(self.cond_input_dim, self.cond_embedding_dim),
            nn.ReLU(),
            nn.Linear(self.cond_embedding_dim, self.cond_embedding_dim),
        )
        concat_width = (self.ts_embedding_dim
                        + self.cond_embedding_dim
                        + self.input_embedding_dim)
        self.cat_feature_processing_layers = nn.Sequential(
            nn.Linear(concat_width, self.res_hidden_dim)
        )

        # Residual trunk.
        self.layers = nn.ModuleList()
        for _ in range(self.res_blocks):
            self.layers.append(ResidualMLPBlock(
                self.res_hidden_dim, self.res_hidden_dim, self.res_hidden_dim))

        self.fc_out = nn.Linear(self.res_hidden_dim, self.output_dim)

        if self.separate_encodings:
            # Decoder that reconstructs cond from its embedding; used for the
            # reconstruction term computed in forward().
            self.cond_decodr_layers = nn.Sequential(
                nn.Linear(self.cond_embedding_dim, self.cond_embedding_dim),
                nn.ReLU(),
                nn.Linear(self.cond_embedding_dim, self.cond_input_dim),
            )

    def forward(self, x, t, cond):
        """
        Predict the denoising target for ``x`` at timestep ``t`` given ``cond``.

        When ``separate_encodings`` is enabled, the mean (over the remaining
        axes) of the summed squared reconstruction error of ``cond`` is
        stored on ``self.cond_recon_diff``, and the conditioning embedding is
        detached before it enters the denoiser.
        """
        x_feat = self.input_processing_layers(x)

        t = t.long()
        t_feat = self.timestep_embedder(t).squeeze(0)

        cond_feat = self.cond_input_processing_layers(cond)

        if self.separate_encodings:
            cond_recon = self.cond_decodr_layers(cond_feat)
            squared_error = torch.sum((cond_recon - cond) ** 2, dim=-1)
            self.cond_recon_diff = squared_error.mean()
            cond_feat = cond_feat.detach()

        hidden = self.cat_feature_processing_layers(
            torch.cat([x_feat, cond_feat, t_feat], dim=-1))

        for block in self.layers:
            hidden = block(hidden)

        return self.fc_out(hidden)

class ModelSimpleMLP(nn.Module):
    """
    Small MLP denoiser for a fixed-size motion vector.

    The motion input (48 features), the history condition (96 features) and
    the timestep are each embedded to 128 features, concatenated and decoded
    back to 48 features by a three-layer MLP.

    Several U-Net style config fields (``model.ch``, ``model.ch_mult``, ...)
    are read and partly stored on the instance, and an unused ``temb.dense``
    stack is created; they are retained so config requirements and
    state_dict keys stay as they were.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        model_cfg = config.model
        ch = model_cfg.ch
        out_ch = model_cfg.out_ch  # read but not used in this class
        ch_mult = tuple(model_cfg.ch_mult)
        num_res_blocks = model_cfg.num_res_blocks
        attn_resolutions = model_cfg.attn_resolutions  # read but not used
        dropout = model_cfg.dropout  # read but not used
        in_channels = model_cfg.in_channels
        resolution = config.data.image_size
        resamp_with_conv = model_cfg.resamp_with_conv  # read but not used
        num_timesteps = config.diffusion.num_diffusion_timesteps
        self.num_timesteps = num_timesteps

        if config.model.type == 'bayesian':
            self.logvar = nn.Parameter(torch.zeros(num_timesteps))

        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # Timestep-embedding MLP container (not used by forward()).
        self.temb = nn.Module()
        self.temb.dense = nn.ModuleList([
            torch.nn.Linear(self.ch, self.temb_ch),
            torch.nn.Linear(self.temb_ch, self.temb_ch),
        ])

        self.motion_input_dim = 2 * 8 * 3  # T * nn_keys * pos_dim
        self.latent_dim = 128
        self.motion_input_processing_layer = nn.Sequential(
            nn.Linear(self.motion_input_dim, self.latent_dim),
            nn.ReLU(),
            nn.Linear(self.latent_dim, self.latent_dim),
        )
        self.history_cond_dim = 4 * 8 * 3
        self.history_input_layer = nn.Sequential(
            nn.Linear(self.history_cond_dim, self.latent_dim),
            nn.ReLU(),
            nn.Linear(self.latent_dim, self.latent_dim),
        )
        self.timestep_positional_encoder = PositionalEncoding(
            self.latent_dim, 0.1, max_len=num_timesteps)
        self.timestep_embedder = TimestepEmbedder(
            self.latent_dim, self.timestep_positional_encoder)

        self.denoiser = nn.Sequential(
            nn.Linear(self.latent_dim * 3, self.latent_dim * 2),
            nn.ReLU(),
            nn.Linear(self.latent_dim * 2, self.latent_dim),
            nn.ReLU(),
            nn.Linear(self.latent_dim, self.motion_input_dim),
        )

    def forward(self, x, t, cond):
        """
        Predict the denoising target for ``x`` at timestep ``t`` given ``cond``.

        A 4-D ``x`` is moved to channels-last and flattened per sample
        (``cond`` is treated the same way), and the prediction is reshaped
        back to the original 4-D layout. Any other rank is used as-is.
        """
        channels_last_shape = None
        if len(x.size()) == 4:
            x = x.contiguous().permute(0, 2, 3, 1).contiguous()
            channels_last_shape = x.size()[1:]
            x = x.view(x.size(0), -1).contiguous()
            cond = cond.contiguous().permute(0, 2, 3, 1).contiguous()
            cond = cond.view(cond.size(0), -1).contiguous()

        x_feat = self.motion_input_processing_layer(x)

        t = t.long()
        t_feat = self.timestep_embedder(t).squeeze(0)

        cond_feat = self.history_input_layer(cond)

        pred_noise = self.denoiser(
            torch.cat([x_feat, cond_feat, t_feat], dim=-1))

        if channels_last_shape is not None:
            # Undo the flatten and the channels-last permutation.
            pred_noise = pred_noise.contiguous().view(-1, *channels_last_shape).contiguous()
            pred_noise = pred_noise.permute(0, 3, 1, 2).contiguous()

        return pred_noise
        

class ResidualMLPBlock(nn.Module):
    """
    Two-layer MLP with a learned linear skip connection.

    Computes ``relu(fc2(relu(fc1(x))) + shortcut(x))``. The skip path is
    always a ``Linear(input_dim, output_dim)``, even when both sizes match.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(ResidualMLPBlock, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.shortcut = nn.Linear(input_dim, output_dim)  # learned skip path

    def forward(self, x):
        """Return the block output; the activation is applied after the sum."""
        hidden = F.relu(self.fc1(x))
        main_path = self.fc2(hidden)
        skip_path = self.shortcut(x)

        # Accumulate in place into the freshly computed main-path tensor.
        main_path.add_(skip_path)
        return F.relu(main_path)




class ProprioAdaptTConv(nn.Module):
    """
    Temporal-convolution encoder that compresses a sequence of 32-feature
    frames into an 8-dimensional vector.

    Each frame goes through a per-frame MLP (32 -> 32 -> 32), the sequence
    is reduced along time by three 1-D convolutions (kernel 9 / stride 2,
    then two of kernel 5 / stride 1), and the flattened result is projected
    to 8 features. The final projection is hard-wired to 3 remaining time
    steps, which corresponds to an input of 29 or 30 frames.
    """

    def __init__(self, ):
        super(ProprioAdaptTConv, self).__init__()
        per_frame_layers = [
            nn.Linear(16 + 16, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 32),
            nn.ReLU(inplace=True),
        ]
        self.channel_transform = nn.Sequential(*per_frame_layers)

        temporal_layers = [
            nn.Conv1d(32, 32, (9,), stride=(2,)),
            nn.ReLU(inplace=True),
            nn.Conv1d(32, 32, (5,), stride=(1,)),
            nn.ReLU(inplace=True),
            nn.Conv1d(32, 32, (5,), stride=(1,)),
            nn.ReLU(inplace=True),
        ]
        self.temporal_aggregation = nn.Sequential(*temporal_layers)

        self.low_dim_proj = nn.Linear(32 * 3, 8)

    def forward(self, x):
        """Encode ``x`` of shape (N, T, 32) into (N, 8); T must reduce to 3 steps."""
        per_frame = self.channel_transform(x)           # (N, T, 32)
        channels_first = per_frame.permute((0, 2, 1))   # (N, 32, T)
        aggregated = self.temporal_aggregation(channels_first)  # (N, 32, 3)
        return self.low_dim_proj(aggregated.flatten(1))
    
    

class TemporalConv(nn.Module):
    """
    History encoder that maps ``hist_len`` frames of 32 features to an
    8-dimensional vector.

    * ``hist_len > 20``: per-frame MLP, three 1-D convolutions along time
      (kernel 9 / stride 2, then two of kernel 5 / stride 1) and a linear
      projection sized from the resulting number of time steps.
    * ``hist_len <= 20``: the history is flattened and fed to a two-layer
      MLP (``32 * hist_len -> 128 -> 8``); no convolutional modules are
      created.

    NOTE(review): for ``hist_len`` between 21 and 24 the computed number of
    remaining time steps is not positive, so the convolutional path cannot
    work for those lengths.
    """

    # (kernel size, stride) of the temporal convolutions, in order.
    _CONV_SPECS = ((9, 2), (5, 1), (5, 1))

    def __init__(self, hist_len=30):
        super(TemporalConv, self).__init__()
        self.hist_len = hist_len

        if hist_len <= 20:
            self.low_dim_proj = nn.Sequential(
                nn.Linear(32 * hist_len, 128),
                nn.ReLU(inplace=True),
                nn.Linear(128, 8),
            )
            return

        self.channel_transform = nn.Sequential(
            nn.Linear(16 + 16, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 32),
            nn.ReLU(inplace=True),
        )

        conv_layers = []
        remaining_steps = hist_len
        for kernel, stride in self._CONV_SPECS:
            conv_layers.append(nn.Conv1d(32, 32, (kernel,), stride=(stride,)))
            conv_layers.append(nn.ReLU(inplace=True))
            # Output length of an unpadded 1-D convolution.
            remaining_steps = (remaining_steps - kernel) // stride + 1
        self.temporal_aggregation = nn.Sequential(*conv_layers)

        self.low_dim_proj = nn.Linear(32 * remaining_steps, 8)

    def forward(self, x):
        """Encode a history tensor; see the class docstring for the two paths."""
        if self.hist_len <= 20:
            # Flatten (frames x 32 features) per sample and use the plain MLP.
            flat_history = x.contiguous().view(-1, self.hist_len * 32)
            return self.low_dim_proj(flat_history)

        per_frame = self.channel_transform(x)           # (N, hist_len, 32)
        channels_first = per_frame.permute((0, 2, 1))   # (N, 32, hist_len)
        aggregated = self.temporal_aggregation(channels_first)
        return self.low_dim_proj(aggregated.flatten(1))



class WorldModel(nn.Module):
    """
    Residual-MLP world model predicting the next state from a history of
    states and actions.

    The architecture is selected by flags on ``config.invdyn``; the first
    matching mode wins, both in ``__init__`` and in ``forward``:

    * ``multi_joint_single_wm`` -- one network per joint (16 joints), or a
      single network shared by all joints when
      ``multi_joint_single_shared_wm`` is set.
    * ``multi_finger_single_wm`` -- one network per finger (4 fingers of
      4 joints each).
    * ``single_hand_wm`` -- one network for all 16 joints.
    * otherwise -- a generic network (``self.layers`` + ``self.fc_out``)
      whose input/output widths depend on ``joint_idx``, ``finger_idx``,
      ``fullhand_wobjstate_wm``, ``wm_pred_joint_idx`` and the optional
      history-context encoder (``hist_context_length > 0``).

    Every network is ``Linear -> res_blocks x ResidualMLPBlock -> Linear``
    with hidden width ``res_hidden_dim``.
    """

    @staticmethod
    def _build_res_mlp(in_dim, hidden_dim, num_blocks, out_dim=None):
        """Build ``Linear(in, hidden)``, ``num_blocks`` residual blocks and,
        when ``out_dim`` is given, a final ``Linear(hidden, out)``."""
        stack = nn.ModuleList()
        stack.append(nn.Linear(in_dim, hidden_dim))
        for _ in range(num_blocks):
            stack.append(ResidualMLPBlock(hidden_dim, hidden_dim, hidden_dim))
        if out_dim is not None:
            stack.append(nn.Linear(hidden_dim, out_dim))
        return stack

    @staticmethod
    def _run_layers(layers, x):
        """Apply ``layers`` to ``x`` one after another and return the result."""
        for layer in layers:
            x = layer(x)
        return x

    def __init__(self, config):
        super().__init__()
        self.config = config

        invdyn = self.config.invdyn
        self.finger_idx = invdyn.finger_idx
        self.joint_idx = invdyn.joint_idx
        self.wm_history_length = invdyn.wm_history_length
        # 0 disables the history-context encoder.
        self.hist_context_length = invdyn.hist_context_length

        self.res_hidden_dim = invdyn.res_hidden_dim
        self.res_blocks = invdyn.res_blocks

        self.hist_context_finger_idx = invdyn.hist_context_finger_idx
        self.wm_pred_joint_idx = invdyn.wm_pred_joint_idx
        self.add_nearing_neighbour = invdyn.add_nearing_neighbour
        self.add_nearing_finger = invdyn.add_nearing_finger

        self.multi_joint_single_wm = invdyn.multi_joint_single_wm
        self.multi_finger_single_wm = invdyn.multi_finger_single_wm
        self.single_hand_wm = invdyn.single_hand_wm

        self.multi_joint_single_shared_wm = invdyn.multi_joint_single_shared_wm
        self.fullhand_wobjstate_wm = invdyn.fullhand_wobjstate_wm

        # ---- input / output widths of the generic network ----
        if self.joint_idx >= 0:
            if self.add_nearing_finger:
                self.model_input_dim = (4 * self.wm_history_length) * 2
            else:
                self.model_input_dim = (1 * self.wm_history_length) * 2
                if self.add_nearing_neighbour:
                    self.model_input_dim = self.model_input_dim + ((1 + 1) * 2)
            self.model_output_dim = 1
        elif self.finger_idx >= 0:
            self.model_input_dim = (4 * self.wm_history_length) * 2
            self.model_output_dim = 4
            print(f"Using finger idx: {self.finger_idx}, model_input_dim: {self.model_input_dim}, model_output_dim: {self.model_output_dim}")
        else:
            if self.fullhand_wobjstate_wm:
                self.model_input_dim = (16 * self.wm_history_length) * 2 + 4
                self.model_output_dim = 16 + 4
            else:
                self.model_input_dim = (16 * self.wm_history_length) * 2
                self.model_output_dim = 16

        if self.wm_pred_joint_idx >= 0:
            self.model_output_dim = 1

        # ---- optional history-context encoder ----
        if self.hist_context_length > 0:
            hist_context_input_dim = self.hist_context_length * (32)
            if self.hist_context_finger_idx >= 0:
                hist_context_input_dim = self.hist_context_length * (4 + 4)

            self.hist_layers = self._build_res_mlp(
                hist_context_input_dim, self.res_hidden_dim, self.res_blocks)
            self.hist_fc_out = nn.Linear(self.res_hidden_dim, self.res_hidden_dim)
            # The encoded context is concatenated to the generic network input.
            self.model_input_dim = self.model_input_dim + self.res_hidden_dim

        # ---- prediction networks ----
        if self.multi_joint_single_wm:
            if self.multi_joint_single_shared_wm:
                self.per_joint_wm_input_dim = (1 * self.wm_history_length) * 2
                self.per_joint_wm_output_dim = 1
                self.joint_shared_wm_layers = self._build_res_mlp(
                    self.per_joint_wm_input_dim, self.res_hidden_dim,
                    self.res_blocks, self.per_joint_wm_output_dim)
            else:
                self.nn_joints = 16
                self.per_joint_wm_layers = nn.ModuleList()
                self.per_joint_wm_input_dim = (1 * self.wm_history_length) * 2
                self.per_joint_wm_output_dim = 1
                for _ in range(self.nn_joints):
                    self.per_joint_wm_layers.append(self._build_res_mlp(
                        self.per_joint_wm_input_dim, self.res_hidden_dim,
                        self.res_blocks, self.per_joint_wm_output_dim))
        elif self.multi_finger_single_wm:
            self.nn_fingers = 4
            self.per_finger_wm_layers = nn.ModuleList()
            self.per_finger_wm_input_dim = (4 * self.wm_history_length) * 2
            self.per_finger_wm_output_dim = 4
            for _ in range(self.nn_fingers):
                self.per_finger_wm_layers.append(self._build_res_mlp(
                    self.per_finger_wm_input_dim, self.res_hidden_dim,
                    self.res_blocks, self.per_finger_wm_output_dim))
        elif self.single_hand_wm:
            self.hand_input_dim = (16 * self.wm_history_length) * 2
            self.hand_output_dim = 16
            self.hand_wm_layers = self._build_res_mlp(
                self.hand_input_dim, self.res_hidden_dim,
                self.res_blocks, self.hand_output_dim)
        else:
            self.layers = self._build_res_mlp(
                self.model_input_dim, self.res_hidden_dim, self.res_blocks)
            self.fc_out = nn.Linear(self.res_hidden_dim, self.model_output_dim)

    def hist_processing(self, hist_state, hist_action):
        """Encode the history context: concatenate ``hist_state`` and
        ``hist_action`` on the last axis and run the history encoder.
        Only usable when ``hist_context_length > 0``."""
        hist_context_input = torch.cat([hist_state, hist_action], dim=-1)
        hist_context_input = self._run_layers(self.hist_layers, hist_context_input)
        return self.hist_fc_out(hist_context_input)

    def forward(self, input_dict):
        """
        Predict the next state.

        Args:
            input_dict: mapping with ``'state'`` and ``'action'`` tensors
                and, for the generic and full-hand-with-object modes when
                ``hist_context_length > 0``, also ``'hist_state'`` and
                ``'hist_action'``. The per-joint / per-finger modes index
                the last axis by joint (presumably the layout is
                batch x history x joints -- confirm with the caller).

        Returns:
            The predicted next state. The per-joint (non-shared) and
            per-finger modes squash predictions to [-1, 1]; the
            full-hand-with-object mode L2-normalises everything after the
            first 16 outputs.
        """
        state = input_dict['state']
        action = input_dict['action']

        if self.multi_joint_single_wm:
            if self.multi_joint_single_shared_wm:
                nn_bsz, nn_joints = state.size(0), state.size(-1)
                # Move joints next to the batch axis and fold them into it so
                # the shared network sees one joint history per row.
                expanded_state = state.contiguous().permute(0, 2, 1).contiguous()
                expanded_action = action.contiguous().permute(0, 2, 1).contiguous()
                expanded_state = expanded_state.contiguous().view(
                    expanded_state.size(0) * expanded_state.size(1), -1)
                expanded_action = expanded_action.contiguous().view(
                    expanded_action.size(0) * expanded_action.size(1), -1)
                expanded_wm_input = torch.cat([expanded_state, expanded_action], dim=-1)
                pred_nex_state = self._run_layers(self.joint_shared_wm_layers, expanded_wm_input)
                pred_nex_state = pred_nex_state.contiguous().view(nn_bsz, nn_joints)
            else:
                per_joint_preds = []
                for i_joint in range(self.nn_joints):
                    cur_joint_wm_input = torch.cat(
                        [state[..., i_joint], action[..., i_joint]], dim=-1)
                    cur_joint_pred = self._run_layers(
                        self.per_joint_wm_layers[i_joint], cur_joint_wm_input)
                    # Squash to [-1, 1].
                    per_joint_preds.append(torch.sigmoid(cur_joint_pred) * 2.0 - 1.0)
                pred_nex_state = torch.cat(per_joint_preds, dim=-1)
            return pred_nex_state

        elif self.multi_finger_single_wm:
            per_finger_preds = []
            for i_finger in range(self.nn_fingers):
                cur_finger_joint_idxes = [4 * i_finger + ii for ii in range(4)]
                cur_finger_state = state[..., cur_finger_joint_idxes]
                cur_finger_action = action[..., cur_finger_joint_idxes]

                cur_finger_state = cur_finger_state.contiguous().view(
                    cur_finger_state.size(0), -1).contiguous()
                cur_finger_action = cur_finger_action.contiguous().view(
                    cur_finger_action.size(0), -1).contiguous()
                cur_finger_wm_input = torch.cat([cur_finger_state, cur_finger_action], dim=-1)

                cur_finger_pred = self._run_layers(
                    self.per_finger_wm_layers[i_finger], cur_finger_wm_input)
                per_finger_preds.append(torch.sigmoid(cur_finger_pred) * 2.0 - 1.0)
            return torch.cat(per_finger_preds, dim=-1)

        elif self.single_hand_wm:
            expanded_state = state.contiguous().view(state.size(0), -1).contiguous()
            expanded_action = action.contiguous().view(action.size(0), -1).contiguous()
            hand_wm_input = torch.cat([expanded_state, expanded_action], dim=-1)
            return self._run_layers(self.hand_wm_layers, hand_wm_input)

        elif self.fullhand_wobjstate_wm:
            expanded_state = state.contiguous().view(state.size(0), -1).contiguous()
            expanded_action = action.contiguous().view(action.size(0), -1).contiguous()
            hand_wm_input = torch.cat([expanded_state, expanded_action], dim=-1)

            # In this mode __init__ builds self.layers / self.fc_out (never
            # hand_wm_layers, which only exists when single_hand_wm is set and
            # that case returns above), so the prediction must go through
            # them. Their input width also includes the encoded history
            # context whenever it is enabled.
            if self.hist_context_length > 0:
                hist_context_input = self.hist_processing(
                    input_dict['hist_state'], input_dict['hist_action'])
                hand_wm_input = torch.cat([hand_wm_input, hist_context_input], dim=-1)

            hand_wm_input = self.fc_out(self._run_layers(self.layers, hand_wm_input))

            hand_pred, obj_pred = hand_wm_input[..., :16], hand_wm_input[..., 16:]
            # Normalise the object part to unit length (torch.norm takes
            # `keepdim`, not `keepdims`); the clamp avoids division by zero.
            obj_norm = torch.norm(obj_pred, dim=-1, p=2, keepdim=True)
            obj_pred = obj_pred / torch.clamp(obj_norm, min=1e-8)
            return torch.cat([hand_pred, obj_pred], dim=-1)

        else:
            # The generic network sees the state and the action expressed as
            # an offset from the state.
            model_input = torch.cat([state, action - state], dim=-1)

            if self.hist_context_length > 0:
                hist_context_input = self.hist_processing(
                    input_dict['hist_state'], input_dict['hist_action'])
                model_input = torch.cat([model_input, hist_context_input], dim=-1)

            model_input = self._run_layers(self.layers, model_input)
            return self.fc_out(model_input)

class QValueModel(nn.Module):
    """Q-value head: a three-layer MLP scoring a (state, action) pair.

    The layer widths are fixed in code: the state input has
    ``(state_dim * 2) * hist_window_length`` = 320 features, the action
    input has ``state_dim`` = 16 features, and the concatenation goes
    through ``Linear(336, 400) -> ReLU -> Linear(400, 300) -> ReLU ->
    Linear(300, 1)``.

    Args:
        config: Accepted for signature parity with the other models in this
            file; this class neither reads nor stores it.
    """

    def __init__(self, config= None):
        super().__init__()
        self.hist_window_length = 10
        self.state_dim = 16

        # Widths of the two inputs that forward() concatenates.
        doubled_state_dim = self.state_dim * 2
        self.state_input_dim = doubled_state_dim * self.hist_window_length
        self.action_input_dim = self.state_dim

        critic_in_dim = self.state_input_dim + self.action_input_dim
        self.l1 = nn.Linear(critic_in_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, 1)

    def forward(self, state, action):
        """Return the Q estimate for ``state`` and ``action``.

        Both tensors are concatenated along dim 1, so they must be at least
        2-D and agree on every other dimension.
        """
        hidden = torch.cat((state, action), dim=1)
        for fc in (self.l1, self.l2):
            hidden = F.relu(fc(hidden))
        return self.l3(hidden)

class VModel(nn.Module):
    """State-value head: a three-layer MLP mapping a state vector to a scalar.

    The input width is fixed in code at
    ``(state_dim * 2) * hist_window_length`` = 320 features, followed by
    ``Linear(320, 400) -> ReLU -> Linear(400, 300) -> ReLU -> Linear(300, 1)``.

    Args:
        config: Stored on ``self.config``; not otherwise read by this class.
    """

    def __init__(self, config= None):
        super().__init__()
        self.config = config

        self.hist_window_length = 10
        self.state_dim = 16

        doubled_state_dim = self.state_dim * 2
        self.state_input_dim = doubled_state_dim * self.hist_window_length

        self.l1 = nn.Linear(self.state_input_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, 1)

    def forward(self, state):
        """Return the value estimate for ``state`` (last dim must be 320)."""
        hidden = state
        for fc in (self.l1, self.l2):
            hidden = F.relu(fc(hidden))
        return self.l3(hidden)



class WorldModelDeltaActions(nn.Module):
    """Residual-MLP that predicts delta actions from state/action features.

    All hyper-parameters are read from ``config.invdyn``:

    * ``joint_idx`` / ``finger_idx`` select the per-step width used for
      sizing: 1 if ``joint_idx >= 0``, else 4 if ``finger_idx >= 0``,
      else 16. The main input has ``width * wm_history_length * 2``
      features and the output has ``width`` features.
    * ``hist_context_length > 0`` enables an extra history-context encoder
      whose ``res_hidden_dim``-wide output is appended to the main input.
      Its input width is ``hist_context_length * 8`` when
      ``hist_context_finger_idx >= 0`` and ``hist_context_length * 32``
      otherwise.
    * ``res_hidden_dim`` / ``res_blocks`` size the residual trunks.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        invdyn_cfg = self.config.invdyn
        self.finger_idx = invdyn_cfg.finger_idx
        self.joint_idx = invdyn_cfg.joint_idx
        self.wm_history_length = invdyn_cfg.wm_history_length
        self.hist_context_length = invdyn_cfg.hist_context_length
        self.res_hidden_dim = invdyn_cfg.res_hidden_dim
        self.res_blocks = invdyn_cfg.res_blocks
        self.hist_context_finger_idx = invdyn_cfg.hist_context_finger_idx

        # Per-step width: single joint, single finger, or the full 16.
        if self.joint_idx >= 0:
            per_step_width = 1
        elif self.finger_idx >= 0:
            per_step_width = 4
        else:
            per_step_width = 16
        self.model_input_dim = (per_step_width * self.wm_history_length) * 2
        self.model_output_dim = per_step_width

        print(f"model_input_dim: {self.model_input_dim}, model_output_dim: {self.model_output_dim}")

        if self.hist_context_length > 0:
            ctx_width = (4 + 4) if self.hist_context_finger_idx >= 0 else 32
            hist_context_input_dim = self.hist_context_length * ctx_width

            ctx_stack = [nn.Linear(hist_context_input_dim, self.res_hidden_dim)]
            ctx_stack.extend(
                ResidualMLPBlock(self.res_hidden_dim, self.res_hidden_dim, self.res_hidden_dim)
                for _ in range(self.res_blocks)
            )
            self.hist_layers = nn.ModuleList(ctx_stack)
            self.hist_fc_out = nn.Linear(self.res_hidden_dim, self.res_hidden_dim)

            # The encoded context is concatenated onto the main input.
            self.model_input_dim = self.model_input_dim + self.res_hidden_dim

        # Main trunk: input projection followed by the residual blocks.
        trunk = [nn.Linear(self.model_input_dim, self.res_hidden_dim)]
        trunk.extend(
            ResidualMLPBlock(self.res_hidden_dim, self.res_hidden_dim, self.res_hidden_dim)
            for _ in range(self.res_blocks)
        )
        self.layers = nn.ModuleList(trunk)

        self.fc_out = nn.Linear(self.res_hidden_dim, self.model_output_dim)

    def hist_processing(self, hist_state, hist_action):
        """Encode the history context into a ``res_hidden_dim``-wide feature.

        ``hist_state`` and ``hist_action`` are concatenated on the last dim,
        passed through ``self.hist_layers`` and projected by
        ``self.hist_fc_out``. Only usable when ``hist_context_length > 0``,
        since the layers are not created otherwise.
        """
        ctx = torch.cat((hist_state, hist_action), dim=-1)
        for block in self.hist_layers:
            ctx = block(ctx)
        return self.hist_fc_out(ctx)

    def forward(self, input_dict):
        """Predict delta actions.

        Args:
            input_dict: Mapping with keys ``'state'`` and ``'action'``; when
                ``hist_context_length > 0`` it must also provide
                ``'hist_state'`` and ``'hist_action'``.

        Returns:
            Tensor whose last dimension is ``model_output_dim``.
        """
        hidden = torch.cat([input_dict['state'], input_dict['action']], dim=-1)

        if self.hist_context_length > 0:
            ctx_feat = self.hist_processing(input_dict['hist_state'], input_dict['hist_action'])
            hidden = torch.cat([hidden, ctx_feat], dim=-1)

        for block in self.layers:
            hidden = block(hidden)
        return self.fc_out(hidden)



class ModelInvDyn(nn.Module):
    """Inverse-dynamics network with a configurable backbone.

    The backbone is selected by ``config.invdyn.model_arch``:

    * ``'mlp'``, ``'resmlp'``, ``'resmlp_gaussian'``, ``'resmlp_moe'``,
      ``'transformer'``, ``'transformer_moe'`` consume one flat feature
      vector per sample, built in ``forward`` by concatenating
      ``history_obs`` and ``history_ref`` (plus the optional history-context
      and extrinsics features) on the last dimension.
    * ``'resmlp_atten'`` instead reshapes ``history_obs`` / ``history_ref``
      into per-timestep sequences and runs them through transformer
      encoders. It has no flat input vector, so it cannot be combined with
      ``hist_context_length > 0`` or ``pred_extrin``.

    Config fields read from ``config.invdyn``: ``history_length``,
    ``history_obs_dim``, ``future_length``, ``future_ref_dim``,
    ``future_act_dim``, ``hist_context_length``, ``train_value_network``,
    ``model_arch``, ``w_hand_root_ornt``; ``res_hidden_dim`` and
    ``res_blocks`` for the ``resmlp*`` variants; and the optional
    ``pred_extrin`` (default ``False``), ``moe_num_experts`` (default 4) and
    ``moe_k`` (default 2).

    Side effects of ``forward``: ``self.extrin_pred`` is set when
    ``pred_extrin`` is enabled, and ``self.reg_sigma`` is set by the
    ``'resmlp_gaussian'`` backbone.

    Raises:
        ValueError: from ``__init__`` if ``model_arch`` is not one of the
            names above.
    """

    # Backbones that consume a single flat feature vector per sample.
    _FLAT_INPUT_ARCHS = ('mlp', 'resmlp', 'resmlp_gaussian', 'resmlp_moe', 'transformer', 'transformer_moe')

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.history_length = self.config.invdyn.history_length
        self.history_obs_dim = self.config.invdyn.history_obs_dim
        self.future_length = self.config.invdyn.future_length
        self.future_ref_dim = self.config.invdyn.future_ref_dim
        self.future_act_dim = self.config.invdyn.future_act_dim

        # Default sizing: flattened history observations + flattened future
        # references in, one action per future step out.
        self.invdyn_model_in_dim = self.history_length * self.history_obs_dim + self.future_length * self.future_ref_dim
        self.invdyn_model_out_dim = self.future_act_dim * self.future_length

        self.invdyn_hist_context_length = self.config.invdyn.hist_context_length

        self.train_value_network = self.config.invdyn.train_value_network
        if self.train_value_network:
            # Value-network mode: the future-reference part of the input is
            # replaced by a fixed 16 features and the output is a scalar.
            self.invdyn_model_in_dim = self.history_length * self.history_obs_dim + 16
            self.invdyn_model_out_dim = 1

        # Optional flag; configs that do not define it get False. getattr
        # mirrors how the optional MoE settings are read further down.
        self.pred_extrin = getattr(self.config.invdyn, 'pred_extrin', False)

        if self.invdyn_hist_context_length > 0:
            self.hist_context_conv = TemporalConv(hist_len=self.invdyn_hist_context_length)
            # NOTE(review): assumes TemporalConv emits 8 features per sample
            # -- confirm against its definition.
            self.invdyn_model_in_dim = self.invdyn_model_in_dim + 8

        if self.pred_extrin:
            self.hist_extrin_pred_layers = ProprioAdaptTConv()
            # NOTE(review): assumes ProprioAdaptTConv emits extrin_low_dim
            # features per sample -- confirm against its definition.
            self.extrin_low_dim = 8
            self.invdyn_model_in_dim += self.extrin_low_dim

        self.model_arch = self.config.invdyn.model_arch

        self.invdyn_w_hand_root_ornt = self.config.invdyn.w_hand_root_ornt

        if self.invdyn_w_hand_root_ornt:
            # Four extra input features; forward() does not add them itself,
            # so the caller presumably includes them in history_obs or
            # history_ref -- TODO confirm.
            self.invdyn_model_in_dim = self.invdyn_model_in_dim + 4

        if self.model_arch == 'mlp':
            self.invdyn_model = nn.Sequential(
                nn.Linear(self.invdyn_model_in_dim, 512), nn.ReLU(),
                nn.Linear(512, 512), nn.ReLU(),
                nn.Linear(512, self.invdyn_model_out_dim)
            )
        elif self.model_arch == 'resmlp':
            self.res_hidden_dim = self.config.invdyn.res_hidden_dim
            self.res_blocks = self.config.invdyn.res_blocks
            self.layers = nn.ModuleList()
            self.layers.append(nn.Linear(self.invdyn_model_in_dim, self.res_hidden_dim))

            # Residual blocks
            for _ in range(self.res_blocks):
                self.layers.append(ResidualMLPBlock(self.res_hidden_dim, self.res_hidden_dim, self.res_hidden_dim))

            self.fc_out = nn.Linear(self.res_hidden_dim, self.invdyn_model_out_dim)

        elif self.model_arch == 'resmlp_gaussian':
            self.res_hidden_dim = self.config.invdyn.res_hidden_dim
            self.res_blocks = self.config.invdyn.res_blocks

            # Two independent trunks: one for the mean, one for log-sigma.
            self.mu_layers = nn.ModuleList()
            self.mu_layers.append(nn.Linear(self.invdyn_model_in_dim, self.res_hidden_dim))
            for _ in range(self.res_blocks):
                self.mu_layers.append(ResidualMLPBlock(self.res_hidden_dim, self.res_hidden_dim, self.res_hidden_dim))
            self.mu_fc_out = nn.Linear(self.res_hidden_dim, self.invdyn_model_out_dim)

            self.sigma_layers = nn.ModuleList()
            self.sigma_layers.append(nn.Linear(self.invdyn_model_in_dim, self.res_hidden_dim))
            for _ in range(self.res_blocks):
                self.sigma_layers.append(ResidualMLPBlock(self.res_hidden_dim, self.res_hidden_dim, self.res_hidden_dim))
            self.sigma_fc_out = nn.Linear(self.res_hidden_dim, self.invdyn_model_out_dim)

        elif self.model_arch == 'resmlp_atten':
            self.res_hidden_dim = self.config.invdyn.res_hidden_dim
            self.res_blocks = self.config.invdyn.res_blocks

            self.num_heads = 4
            self.ff_size = self.res_hidden_dim * 4
            self.dropout = 0.1
            self.activation = 'gelu'
            self.atten_nn_layers = 2

            # This backbone works per timestep, so both the input and output
            # widths are per-step rather than flattened over time.
            self.res_model_in_dim = self.history_obs_dim
            self.invdyn_model_out_dim = self.future_act_dim
            self.layers = nn.ModuleList()
            self.layers.append(nn.Linear(self.res_model_in_dim, self.res_hidden_dim))

            # Residual blocks
            for _ in range(self.res_blocks):
                self.layers.append(ResidualMLPBlock(self.res_hidden_dim, self.res_hidden_dim, self.res_hidden_dim))

            self.hist_positional_encoding = PositionalEncoding(self.res_hidden_dim, self.dropout, max_len=self.history_length + 1)

            seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.res_hidden_dim,
                                                            nhead=self.num_heads,
                                                            dim_feedforward=self.ff_size,
                                                            dropout=self.dropout,
                                                            activation=self.activation)
            self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
                                                         num_layers=self.atten_nn_layers)

            self.future_layers = nn.ModuleList()
            self.future_layers.append(nn.Linear(self.future_ref_dim, self.res_hidden_dim))
            # Residual blocks
            for _ in range(self.res_blocks):
                self.future_layers.append(ResidualMLPBlock(self.res_hidden_dim, self.res_hidden_dim, self.res_hidden_dim))
            # This encoding is applied to a future_length-long sequence, so
            # the table must have at least future_length rows. The max()
            # keeps the historical size (history_length + 1) whenever that
            # was already sufficient, so the registered buffer keeps its
            # shape for every configuration that worked before.
            self.future_positional_encoding = PositionalEncoding(
                self.res_hidden_dim, self.dropout,
                max_len=max(self.history_length + 1, self.future_length))

            # The future encoder sees [history summary, future feature]
            # concatenated on the feature dim, hence twice the hidden width.
            futureTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.res_hidden_dim + self.res_hidden_dim,
                                                            nhead=self.num_heads,
                                                            dim_feedforward=self.ff_size,
                                                            dropout=self.dropout,
                                                            activation=self.activation)

            self.futureTransEncoder = nn.TransformerEncoder(futureTransEncoderLayer, num_layers=self.atten_nn_layers)

            self.fc_out = nn.Linear(self.res_hidden_dim + self.res_hidden_dim, self.invdyn_model_out_dim)

        elif self.model_arch == 'resmlp_moe':
            self.res_hidden_dim = self.config.invdyn.res_hidden_dim
            self.res_blocks = self.config.invdyn.res_blocks
            self.num_experts = getattr(self.config.invdyn, 'moe_num_experts', 4)
            self.k = getattr(self.config.invdyn, 'moe_k', 2)
            self.invdyn_model = ResidualMLPWithMOE(
                input_dim=self.invdyn_model_in_dim,
                hidden_dim=self.res_hidden_dim,
                output_dim=self.invdyn_model_out_dim,
                num_experts=self.num_experts,
                k=self.k,
                num_layers=self.res_blocks
            )
        elif self.model_arch == 'transformer':
            self.transformer_dim = 128
            self.transformer_nhead = 4
            self.transformer_ff = 4 * self.transformer_dim
            self.transformer_layers = 2
            self.transformer_dropout = 0.1
            self.in_proj = nn.Linear(self.invdyn_model_in_dim, self.transformer_dim)
            # No positional encoding here: forward() feeds a length-1 sequence.
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=self.transformer_dim,
                nhead=self.transformer_nhead,
                dim_feedforward=self.transformer_ff,
                dropout=self.transformer_dropout,
                activation='gelu'
            )
            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=self.transformer_layers)
            self.out_proj = nn.Linear(self.transformer_dim, self.invdyn_model_out_dim)
        elif self.model_arch == 'transformer_moe':
            self.transformer_dim = 128
            self.transformer_nhead = 4
            self.transformer_ff = 4 * self.transformer_dim
            self.transformer_layers = 2
            self.transformer_dropout = 0.1
            self.in_proj = nn.Linear(self.invdyn_model_in_dim, self.transformer_dim)
            self.pos_encoder = PositionalEncoding(self.transformer_dim, self.transformer_dropout, max_len=self.history_length + self.future_length + 8)
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=self.transformer_dim,
                nhead=self.transformer_nhead,
                dim_feedforward=self.transformer_ff,
                dropout=self.transformer_dropout,
                activation='gelu'
            )
            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=self.transformer_layers)
            # MOE after transformer
            self.moe_hidden_dim = 128
            self.moe_output_dim = 128
            self.num_experts = getattr(self.config.invdyn, 'moe_num_experts', 4)
            self.k = getattr(self.config.invdyn, 'moe_k', 2)
            self.moe_layer = MOELayer(self.transformer_dim, self.moe_hidden_dim, self.moe_output_dim, self.num_experts, self.k)
            self.out_proj = nn.Linear(self.moe_output_dim, self.invdyn_model_out_dim)
        else:
            raise ValueError(f"Invalid model architecture: {self.model_arch}")

    def forward(self, history_obs, history_ref, history_extrin=None, hist_context=None):
        """Run the selected backbone.

        Args:
            history_obs: History observations. Flat-input backbones
                concatenate it with ``history_ref`` on the last dim;
                ``'resmlp_atten'`` reshapes it to
                ``(batch, history_length, -1)``.
            history_ref: Future references. ``'resmlp_atten'`` reshapes it
                to ``(batch, future_length, -1)``.
            history_extrin: Input to the extrinsics predictor; only read
                when ``pred_extrin`` is enabled.
            hist_context: Input to the history-context encoder; only read
                when ``hist_context_length > 0``.

        Returns:
            The backbone output. For ``'resmlp_atten'`` the per-step outputs
            are flattened to ``(batch, future_length * future_act_dim)``.

        Raises:
            ValueError: if the architecture has no flat input vector but
                ``hist_context_length > 0`` or ``pred_extrin`` is set, or if
                ``model_arch`` is not recognised.
        """
        if self.model_arch in self._FLAT_INPUT_ARCHS:
            invdyn_model_input = torch.cat(
                [ history_obs, history_ref], dim=-1 # nn_bsz x input_feature_dims
            )
        elif self.invdyn_hist_context_length > 0 or self.pred_extrin:
            # The optional features below are appended to the flat input
            # vector, which this architecture never builds. Fail with a
            # clear message instead of an UnboundLocalError further down.
            raise ValueError(
                f"model_arch '{self.model_arch}' has no flat input vector; "
                "hist_context_length > 0 and pred_extrin are not supported with it"
            )

        if self.invdyn_hist_context_length > 0:
            hist_context_input = self.hist_context_conv(hist_context)
            invdyn_model_input = torch.cat(
                [ invdyn_model_input, hist_context_input ], dim=-1
            )

        if self.pred_extrin:
            extrin_pred = self.hist_extrin_pred_layers(history_extrin)
            extrin_pred = torch.tanh(extrin_pred)
            # Kept on the module with its graph attached (presumably for a
            # loss computed by the caller -- confirm); the copy fed to the
            # backbone is detached so backbone gradients do not reach the
            # extrinsics predictor.
            self.extrin_pred = extrin_pred
            invdyn_model_input = torch.cat(
                [ invdyn_model_input, extrin_pred.detach()], dim=-1
            )

        if self.model_arch == 'mlp':
            invdyn_model_output = self.invdyn_model(invdyn_model_input)

        elif self.model_arch == 'resmlp':
            for layer in self.layers:
                invdyn_model_input = layer(invdyn_model_input)
            invdyn_model_output = self.fc_out(invdyn_model_input)

        elif self.model_arch == 'resmlp_gaussian':
            mu = invdyn_model_input
            for layer in self.mu_layers:
                mu = layer(mu)
            mu = self.mu_fc_out(mu)

            # The sigma trunk predicts log-sigma.
            sigma = invdyn_model_input
            for layer in self.sigma_layers:
                sigma = layer(sigma)
            sigma = self.sigma_fc_out(sigma)

            # Reparameterised sample: mu + eps * exp(log_sigma).
            invdyn_model_output = mu + torch.randn_like(sigma) * sigma.exp()
            # Regulariser exposed on the module: minimising it pushes
            # log-sigma up so the predicted spread does not become too small.
            self.reg_sigma = -sigma.sum(dim=-1).mean()

        elif self.model_arch == 'resmlp_atten':
            # Encode each history step, then let the encoder attend over time.
            flatten_history_obs = history_obs.contiguous().view(history_obs.size(0), self.history_length, -1)
            for layer in self.layers:
                flatten_history_obs = layer(flatten_history_obs)
            # The encoders are sequence-first: (time, batch, feature).
            flatten_history_obs = flatten_history_obs.contiguous().transpose(1, 0).contiguous()
            flatten_history_obs = self.hist_positional_encoding(flatten_history_obs)

            flatten_history_feats = self.seqTransEncoder(flatten_history_obs) # nn_hist_ts x nn_bsz x nn_latent_dim
            # Use the last history step as the summary and tile it over
            # every future step.
            flatten_history_feats = flatten_history_feats[-1:].contiguous().repeat(self.future_length, 1, 1).contiguous()

            flatten_future = history_ref.contiguous().view(history_ref.size(0), self.future_length, -1)
            for layer in self.future_layers:
                flatten_future = layer(flatten_future)
            flatten_future = flatten_future.contiguous().transpose(1, 0).contiguous()
            flatten_future = self.future_positional_encoding(flatten_future)
            history_future = torch.cat([flatten_history_feats, flatten_future], dim=-1)
            flatten_future = self.futureTransEncoder(history_future)
            flatten_future = flatten_future.contiguous().transpose(1, 0).contiguous()
            future_out = self.fc_out(flatten_future) # nn_bsz x nn_ts x out_dim
            invdyn_model_output = future_out.contiguous().view(future_out.size(0), -1)

        elif self.model_arch == 'resmlp_moe':
            invdyn_model_output = self.invdyn_model(invdyn_model_input)
        elif self.model_arch == 'transformer':
            # [batch, features] -> [seq_len, batch, features] for transformer
            # For this MLP-style input, treat batch as sequence of length 1
            x = self.in_proj(invdyn_model_input).unsqueeze(0)  # [1, batch, d_model]
            x = self.transformer_encoder(x)  # [1, batch, d_model]
            x = x.squeeze(0)  # [batch, d_model]
            invdyn_model_output = self.out_proj(x)
        elif self.model_arch == 'transformer_moe':
            x = self.in_proj(invdyn_model_input).unsqueeze(0)  # [1, batch, d_model]
            x = self.pos_encoder(x)
            x = self.transformer_encoder(x)  # [1, batch, d_model]
            x = x.squeeze(0)  # [batch, d_model]
            x = self.moe_layer(x)  # [batch, moe_output_dim]
            invdyn_model_output = self.out_proj(x)
        else:
            raise ValueError(f"Invalid model architecture: {self.model_arch}")

        return invdyn_model_output

